{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b518b04cbfe0"
      },
      "source": [
        "##### Copyright 2020 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "906e07f6e562"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c359f002e834"
      },
      "source": [
        "# Training Keras models with TensorFlow Cloud"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f5c893a15fac"
      },
      "source": [
        "<table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://www.tensorflow.org/guide/keras/training_keras_models_on_cloud\"><img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" />View on TensorFlow.org</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/docs/blob/snapshot-keras/site/en/guide/keras/training_keras_models_on_cloud.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/keras-team/keras-io/blob/master/guides/training_keras_models_on_cloud.py\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a href=\"https://storage.googleapis.com/tensorflow_docs/docs/site/en/guide/keras/training_keras_models_on_cloud.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "552fa1053c7b"
      },
      "source": [
        "## Introduction\n",
        "\n",
        "TensorFlow Cloud is a library that makes it easier to do training and\n",
        "hyperparameter tuning of Keras models on Google Cloud.\n",
        "\n",
        "Using TensorFlow Cloud's `run` API, you can send your model code directly to\n",
        "your Google Cloud account, and use Google Cloud compute resources without\n",
        "needing to login and interact with the Cloud UI (once you have set up your\n",
        "project in the console).\n",
        "\n",
        "This means that you can use your Google Cloud compute resources from inside\n",
        "directly a Python notebook: a notebook just like this one! You can also send\n",
        "models to Google Cloud from a plain `.py` Python script.\n",
        "\n",
        "## Simple example\n",
        "\n",
        "This is a simple introductory example to demonstrate how to train a model\n",
        "remotely using [TensorFlow Cloud](https://tensorflow.org/cloud) and Google\n",
        "Cloud.\n",
        "\n",
        "You can just read through it to get an idea of how this works, or you can run\n",
        "the notebook in Google Colab. Running the notebook requires connecting to a\n",
        "Google Cloud account and entering your credentials and project ID. See\n",
        "[Setting Up and Connecting To Your Google Cloud Account](https://github.com/tensorflow/cloud/blob/master/g3doc/tutorials/google_cloud_project_setup_instructions.ipynb)\n",
        "if you don't have an account yet or are not sure how to set up a project in the\n",
        "console."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "3768731d025f"
      },
      "source": [
        "## Import required modules\n",
        "\n",
        "This guide requires TensorFlow Cloud, which you can install via:\n",
        "\n",
        "`pip install tensorflow-cloud`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "2dee4b3b0f30"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import sys\n",
        "import tensorflow as tf\n",
        "import tensorflow_cloud as tfc"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "e264e24fd8f1"
      },
      "source": [
        "## Project Configurations\n",
        "\n",
        "Set project parameters. If you don't know what your `GCP_PROJECT_ID` or\n",
        "`GCS_BUCKET` should be, see\n",
        "[Setting Up and Connecting To Your Google Cloud Account](https://github.com/tensorflow/cloud/blob/master/g3doc/tutorials/google_cloud_project_setup_instructions.ipynb).\n",
        "\n",
        "The `JOB_NAME` is optional, and you can set it to any string. If you are doing\n",
        "multiple training experiemnts (for example) as part of a larger project, you may\n",
        "want to give each of them a unique `JOB_NAME`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "31dbaa3f9b16"
      },
      "outputs": [],
      "source": [
        "# Set Google Cloud Specific parameters\n",
        "\n",
        "# TODO: Please set GCP_PROJECT_ID to your own Google Cloud project ID.\n",
        "GCP_PROJECT_ID = \"YOUR_PROJECT_ID\"  # @param {type:\"string\"}\n",
        "\n",
        "# TODO: set GCS_BUCKET to your own Google Cloud Storage (GCS) bucket.\n",
        "GCS_BUCKET = \"YOUR_GCS_BUCKET_NAME\"  # @param {type:\"string\"}\n",
        "\n",
        "# DO NOT CHANGE: Currently only the 'us-central1' region is supported.\n",
        "REGION = \"us-central1\"\n",
        "\n",
        "# OPTIONAL: You can change the job name to any string.\n",
        "JOB_NAME = \"mnist\"  # @param {type:\"string\"}\n",
        "\n",
        "# Setting location were training logs and checkpoints will be stored\n",
        "GCS_BASE_PATH = f\"gs://{GCS_BUCKET}/{JOB_NAME}\"\n",
        "TENSORBOARD_LOGS_DIR = os.path.join(GCS_BASE_PATH, \"logs\")\n",
        "MODEL_CHECKPOINT_DIR = os.path.join(GCS_BASE_PATH, \"checkpoints\")\n",
        "SAVED_MODEL_DIR = os.path.join(GCS_BASE_PATH, \"saved_model\")"
      ]
    },
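    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The cell above builds the storage locations with `os.path.join`, which uses the\n",
        "path separator of the local operating system. That is fine on Linux runtimes,\n",
        "but it would insert backslashes on Windows. As a minimal sketch, assuming you\n",
        "want the same `gs://` layout on any OS, the hypothetical helper\n",
        "`build_job_paths` below uses `posixpath.join` from the standard library\n",
        "instead. The bucket name is a placeholder, not a real bucket.\n",
        "\n",
        "```python\n",
        "import posixpath\n",
        "\n",
        "\n",
        "def build_job_paths(bucket, job_name):\n",
        "    # posixpath.join always uses '/', which is what gs:// URIs expect,\n",
        "    # whereas os.path.join uses the separator of the local operating system.\n",
        "    base = f'gs://{bucket}/{job_name}'\n",
        "    return {\n",
        "        'logs': posixpath.join(base, 'logs'),\n",
        "        'checkpoints': posixpath.join(base, 'checkpoints'),\n",
        "        'saved_model': posixpath.join(base, 'saved_model'),\n",
        "    }\n",
        "\n",
        "\n",
        "# 'example-bucket' is a hypothetical bucket name used only for illustration.\n",
        "paths = build_job_paths('example-bucket', 'mnist')\n",
        "print(paths['logs'])\n",
        "```\n",
        "\n",
        "The three keys mirror `TENSORBOARD_LOGS_DIR`, `MODEL_CHECKPOINT_DIR`, and\n",
        "`SAVED_MODEL_DIR` from the configuration cell."
      ]
    },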
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eed0e648608a"
      },
      "source": [
        "## Authenticating the notebook to use your Google Cloud Project\n",
        "\n",
        "This code authenticates the notebook, checking your valid Google Cloud\n",
        "credentials and identity. It is inside the `if not tfc.remote()` block to ensure\n",
        "that it is only run in the notebook, and will not be run when the notebook code\n",
        "is sent to Google Cloud.\n",
        "\n",
        "Note: For Kaggle Notebooks click on \"Add-ons\"->\"Google Cloud SDK\" before running\n",
        "the cell below."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f0a9300ba7bd"
      },
      "outputs": [],
      "source": [
        "# Using tfc.remote() to ensure this code only runs in notebook\n",
        "if not tfc.remote():\n",
        "\n",
        "    # Authentication for Kaggle Notebooks\n",
        "    if \"kaggle_secrets\" in sys.modules:\n",
        "        from kaggle_secrets import UserSecretsClient\n",
        "\n",
        "        UserSecretsClient().set_gcloud_credentials(project=GCP_PROJECT_ID)\n",
        "\n",
        "    # Authentication for Colab Notebooks\n",
        "    if \"google.colab\" in sys.modules:\n",
        "        from google.colab import auth\n",
        "\n",
        "        auth.authenticate_user()\n",
        "        os.environ[\"GOOGLE_CLOUD_PROJECT\"] = GCP_PROJECT_ID"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "689ecfc05952"
      },
      "source": [
        "## Model and data setup\n",
        "\n",
        "From here we are following the basic procedure for setting up a simple Keras\n",
        "model to run classification on the MNIST dataset.\n",
        "\n",
        "### Load and split data\n",
        "\n",
        "Read raw data and split to train and test data sets."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "c8dbb31a610b"
      },
      "outputs": [],
      "source": [
        "(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()\n",
        "\n",
        "x_train = x_train.reshape((60000, 28 * 28))\n",
        "x_train = x_train.astype(\"float32\") / 255"
      ]
    },
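    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The preprocessing above flattens each 28x28 image into a 784-element vector and\n",
        "scales pixel values into the [0, 1] range. The sketch below shows the same two\n",
        "steps on a tiny synthetic array, so it runs without downloading MNIST. The\n",
        "array `fake_images` is made up for illustration, and `-1` is used in place of\n",
        "the hardcoded `60000` so that the batch size is inferred.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "# Synthetic stand-in for two 28x28 grayscale images with pixel values 0-255.\n",
        "fake_images = np.full((2, 28, 28), 255, dtype='uint8')\n",
        "fake_images[0] = 0\n",
        "\n",
        "# Flatten each image into a 784-element vector; -1 infers the batch size.\n",
        "flat = fake_images.reshape((-1, 28 * 28))\n",
        "# Convert to float32 and scale pixel values into the [0, 1] range.\n",
        "flat = flat.astype('float32') / 255\n",
        "\n",
        "print(flat.shape, flat.dtype, flat.min(), flat.max())\n",
        "```\n",
        "\n",
        "Inferring the first dimension keeps the same code usable on a subset of the\n",
        "data, such as the 100 examples used for the quick validation run."
      ]
    },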
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6fe4bce33f09"
      },
      "source": [
        "### Create a model and prepare for training\n",
        "\n",
        "Create a simple model and set up a few callbacks for it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "67b17e1debc5"
      },
      "outputs": [],
      "source": [
        "from tensorflow.keras import layers\n",
        "\n",
        "model = tf.keras.Sequential(\n",
        "    [\n",
        "        tf.keras.layers.Dense(512, activation=\"relu\", input_shape=(28 * 28,)),\n",
        "        tf.keras.layers.Dropout(0.2),\n",
        "        tf.keras.layers.Dense(10, activation=\"softmax\"),\n",
        "    ]\n",
        ")\n",
        "\n",
        "model.compile(\n",
        "    loss=\"sparse_categorical_crossentropy\",\n",
        "    optimizer=tf.keras.optimizers.Adam(),\n",
        "    metrics=[\"accuracy\"],\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "694ea19488d1"
      },
      "source": [
        "### Quick validation training\n",
        "\n",
        "We'll train the model for one (1) epoch just to make sure everything is set up\n",
        "correctly, and we'll wrap that training command in `if not` `tfc.remote`, so\n",
        "that it only happens here in the runtime environment in which you are reading\n",
        "this, not when it is sent  to Google Cloud."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "9babd8771221"
      },
      "outputs": [],
      "source": [
        "if not tfc.remote():\n",
        "    # Run the training for 1 epoch and a small subset of the data to validate setup\n",
        "    model.fit(x=x_train[:100], y=y_train[:100], validation_split=0.2, epochs=1)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ae5e255dc88a"
      },
      "source": [
        "## Prepare for remote training\n",
        "\n",
        "The code below will only run when the notebook code is sent to Google Cloud, not\n",
        "inside the runtime in which you are reading this.\n",
        "\n",
        "First, we set up callbacks which will:\n",
        "\n",
        "* Create logs for [TensorBoard](https://www.tensorflow.org/tensorboard).\n",
        "* Create [checkpoints](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/model_checkpoint/) and save them to the checkpoints\n",
        "directory specified above.\n",
        "* Stop model training if loss is not improving sufficiently.\n",
        "\n",
        "Then we call `model.fit` and `model.save`, which (when this code is running on\n",
        "Google Cloud) which actually run the full training (100 epochs) and then save\n",
        "the trained model in the GCS Bucket and directory defined above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "99c73ab7b7ae"
      },
      "outputs": [],
      "source": [
        "if tfc.remote():\n",
        "    # Configure Tensorboard logs\n",
        "    callbacks = [\n",
        "        tf.keras.callbacks.TensorBoard(log_dir=TENSORBOARD_LOGS_DIR),\n",
        "        tf.keras.callbacks.ModelCheckpoint(MODEL_CHECKPOINT_DIR, save_best_only=True),\n",
        "        tf.keras.callbacks.EarlyStopping(monitor=\"loss\", min_delta=0.001, patience=3),\n",
        "    ]\n",
        "\n",
        "    model.fit(\n",
        "        x=x_train, y=y_train, epochs=100, validation_split=0.2, callbacks=callbacks,\n",
        "    )\n",
        "\n",
        "    model.save(SAVED_MODEL_DIR)"
      ]
    },
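    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the early-stopping rule concrete, here is a simplified pure-Python\n",
        "sketch of the criterion configured above: monitor the training loss with\n",
        "`min_delta=0.001` and `patience=3`. The function `epochs_until_stop` and the\n",
        "loss values are hypothetical. This is not the Keras implementation, and it\n",
        "ignores options such as `baseline` and `restore_best_weights`.\n",
        "\n",
        "```python\n",
        "def epochs_until_stop(losses, min_delta=0.001, patience=3):\n",
        "    # Return the 1-based epoch at which training would stop.\n",
        "    best = float('inf')\n",
        "    wait = 0  # consecutive epochs without a sufficient improvement\n",
        "    for epoch, loss in enumerate(losses, start=1):\n",
        "        if loss < best - min_delta:\n",
        "            # Improved by more than min_delta: record it and reset the counter.\n",
        "            best = loss\n",
        "            wait = 0\n",
        "        else:\n",
        "            wait += 1\n",
        "            if wait >= patience:\n",
        "                return epoch\n",
        "    # The criterion never triggered, so every epoch ran.\n",
        "    return len(losses)\n",
        "\n",
        "\n",
        "# Made-up loss values: the loss plateaus near 0.4 after the second epoch.\n",
        "losses = [0.5, 0.4, 0.3995, 0.3999, 0.3993, 0.2]\n",
        "stop_epoch = epochs_until_stop(losses)\n",
        "print(stop_epoch)\n",
        "```\n",
        "\n",
        "With these values training stops after epoch 5: epochs 3 to 5 each fail to beat\n",
        "the best loss (0.4) by more than 0.001, so the final value of 0.2 is never\n",
        "reached. This is why `epochs=100` is an upper bound rather than a fixed count."
      ]
    },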
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cb8505bf596d"
      },
      "source": [
        "## Start the remote training\n",
        "\n",
        "TensorFlow Cloud takes all the code from its local execution environment (this\n",
        "notebook), wraps it up, and sends it to Google Cloud for execution. (That's why\n",
        "the `if` and `if not` `tfc.remote` wrappers are important.)\n",
        "\n",
        "This step will prepare your code from this notebook for remote execution and\n",
        "then start a remote training job on Google Cloud Platform to train the model.\n",
        "\n",
        "First we add the `tensorflow-cloud` Python package to a `requirements.txt` file,\n",
        "which will be sent along with the code in this notebook. You can add other\n",
        "packages here as needed.\n",
        "\n",
        "Then a GPU and a CPU image are specified. You only need to specify one or the\n",
        "other; the GPU is used in the code that follows.\n",
        "\n",
        "Finally, the heart of TensorFlow cloud: the call to `tfc.run`. When this is\n",
        "executed inside this notebook, all the code from this notebook, and the rest of\n",
        "the files in this directory, will be packaged and sent to Google Cloud for\n",
        "execution. The parameters on the `run` method specify the details of the  GPU\n",
        "CPU images are specified. You only need to specify one or the other; the GPU is\n",
        "used in the code that follows.\n",
        "\n",
        "Finally, the heart of TensorFlow cloud: the call to `tfc.run`. When this is\n",
        "executed inside this notebook, all the code from this notebook, and the rest of\n",
        "the files in this directory, will be packaged and sent to Google Cloud for\n",
        "execution. The parameters on the `run` method specify the details of the  GPU\n",
        "and CPU images are specified. You only need to specify one or the other; the GPU\n",
        "is used in the code that follows.\n",
        "\n",
        "Finally, the heart of TensorFlow cloud: the call to `tfc.run`. When this is\n",
        "executed inside this notebook, all the code from this notebook, and the rest of\n",
        "the files in this directory, will be packaged and sent to Google Cloud for\n",
        "execution. The parameters on the `run` method specify the details of the\n",
        "execution environment and the distribution strategy (if any) to be used.\n",
        "\n",
        "Once the job is submitted you can go to the next step to monitor the jobs\n",
        "progress via Tensorboard."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "14e001767c59"
      },
      "outputs": [],
      "source": [
        "# If you are using a custom image you can install modules via requirements\n",
        "# txt file.\n",
        "with open(\"requirements.txt\", \"w\") as f:\n",
        "    f.write(\"tensorflow-cloud\\n\")\n",
        "\n",
        "# Optional: Some recommended base images. If you provide none the system\n",
        "# will choose one for you.\n",
        "TF_GPU_IMAGE = \"gcr.io/deeplearning-platform-release/tf2-cpu.2-5\"\n",
        "TF_CPU_IMAGE = \"gcr.io/deeplearning-platform-release/tf2-gpu.2-5\"\n",
        "\n",
        "# Submit a single node training job using GPU.\n",
        "tfc.run(\n",
        "    distribution_strategy=\"auto\",\n",
        "    requirements_txt=\"requirements.txt\",\n",
        "    docker_config=tfc.DockerConfig(\n",
        "        parent_image=TF_GPU_IMAGE, image_build_bucket=GCS_BUCKET\n",
        "    ),\n",
        "    chief_config=tfc.COMMON_MACHINE_CONFIGS[\"K80_1X\"],\n",
        "    job_labels={\"job\": JOB_NAME},\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "b1efeedcc31d"
      },
      "source": [
        "## Training Results\n",
        "\n",
        "### Reconnect your Colab instance\n",
        "\n",
        "Most remote training jobs are long running. If you are using Colab, it may time\n",
        "out before the training results are available.\n",
        "\n",
        "In that case, **rerun the following sections in order** to reconnect and\n",
        "configure your Colab instance to access the training results.\n",
        "\n",
        "1.   Import required modules\n",
        "2.   Project Configurations\n",
        "3.   Authenticating the notebook to use your Google Cloud Project\n",
        "\n",
        "**DO NOT** rerun the rest of the code.\n",
        "\n",
        "### Load Tensorboard\n",
        "\n",
        "While the training is in progress you can use Tensorboard to view the results.\n",
        "Note the results will show only after your training has started. This may take a\n",
        "few minutes."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "e8bb043598b8"
      },
      "outputs": [],
      "source": [
        "# Commented out IPython magic to ensure Python compatibility.\n",
        "# %load_ext tensorboard\n",
        "# %tensorboard --logdir $TENSORBOARD_LOGS_DIR"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "607aee95d918"
      },
      "source": [
        "## Load your trained model\n",
        "\n",
        "Once training is complete, you can retrieve your model from the GCS Bucket you\n",
        "specified above."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "74c9146358b0"
      },
      "outputs": [],
      "source": [
        "# trained_model = tf.keras.models.load_model(SAVED_MODEL_DIR)\n",
        "# trained_model.summary()"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "name": "training_keras_models_on_cloud.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
